{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "546fa208-c19c-405b-81b1-cf99fcb4a830",
   "metadata": {
    "tags": []
   },
   "source": [
    "# UCI Adult Dataset or Census Income\n",
    "\n",
    "This is a very popular ML task, with tabular data. The objective is to predict whether income exceeds $50K/yr based on census data. \n",
    "Also known as \"Census Income\" dataset.\n",
    "\n",
    "The data is old and biased on different ways ... but it can be used opaquely for ML experimentation.\n",
    "\n",
    "The source code of this example in one piece can be seen in the demo under [.../examples/adult/demo/](https://github.com/gomlx/gomlx/tree/main/examples/adult/demo).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ab071e3-227f-4c77-b864-2b7e5a5ccbdf",
   "metadata": {},
   "source": [
    "## Environment Set Up\n",
    "\n",
    "Let's set up `go.mod` to use the local copy of GoMLX, so it can be developed jointly the dataset code with the model. That's often how data pre-processing and model code is developed together with experimentation.\n",
    "\n",
    "If you are not changing code, feel free to simply skip this cell. Or if you used a different directory for you projects, change it below.\n",
    "\n",
    "Notice the directory `${HOME}/Projects/gomlx` is where the GoMLX code is copied by default in [its Docker](https://hub.docker.com/repository/docker/janpfeifer/gomlx_jupyterlab/general).\n",
    "\n",
    "For this example we are forcing it to use the CPU backend, even if there is a GPU available -- it is such a small model, not worth going to the GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9ba46cb6-059d-4387-876c-665f778f2447",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t- Added replace rule for module \"github.com/gomlx/gomlx\" to local directory \"/home/janpf/Projects/gomlx\".\n",
      "\t- Added replace rule for module \"github.com/gomlx/gopjrt\" to local directory \"/home/janpf/Projects/gopjrt\".\n"
     ]
    }
   ],
   "source": [
    "!*rm -f go.work && go work init && go work use . \"${HOME}/Projects/gomlx\" \"${HOME}/Projects/gopjrt\"\n",
    "%goworkfix"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78cf991f-4fa5-4bae-8e49-ca67fb8dbd5b",
   "metadata": {},
   "source": [
    "## Data Preparation\n",
    "\n",
    "GoMLX provides [a simple `adult` library](https://pkg.go.dev/github.com/gomlx/gomlx/examples/adult) to facilitate downdoaling and preprocessing the data. Data is available in [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/Adult).\n",
    "\n",
    "After downloading the data and validating the checksum (both training and testing), it generates the quantiles for the continuous features, and the vocabularies for the categorical features. It saves all this info for faster restart later in a binary file. So this won't be necessary a second time.\n",
    "\n",
    "The quantiles are used to calibrate the values, using a piece-wise-lienar calibration, very good for these things. See [`layers.PieceWiseLinearCalibration` documentation](https://pkg.go.dev/github.com/gomlx/gomlx@v0.1.0/ml/layers#PieceWiseLinearCalibration).\n",
    "\n",
    "We create a flag `--data` to define the directory where to save the intermediary files: downloaded and preprocessed datasets.\n",
    "In this examle we set it to `~/work/uci-adult`. Verbosity can be contolled with the `--verbosity` flag. \n",
    "\n",
    "We set default in Go for these flags, but they can easily be reset for a new run by providing them after the `%%` Jupyter kernel meta-command -- in indicates that the subsequent lines should be put in to a `func main`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9418be86-70bc-41ad-9e0a-3a5e6bf52b51",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Sample Categorical: (24.08% positive ratio, 23.86% weighted positive ratio)\n",
      "\tRow 0:\t[7 10 5 1 2 5 2 39]\n",
      "\tRow 1:\t[6 10 3 4 1 5 2 39]\n",
      "\tRow 2:\t[4 12 1 6 2 5 2 39]\n",
      "\t...\n",
      "\tRow 32558:\t[4 12 7 1 5 5 1 39]\n",
      "\tRow 32559:\t[4 12 5 1 4 5 2 39]\n",
      "\tRow 32560:\t[5 12 3 4 6 5 1 39]\n",
      "\n",
      "Sample Continuous:\n",
      "\tRow 0:\t[39 13 2174 0 40]\n",
      "\tRow 1:\t[50 13 0 0 13]\n",
      "\tRow 2:\t[38 9 0 0 40]\n",
      "\t...\n",
      "\tRow 32558:\t[58 9 0 0 40]\n",
      "\tRow 32559:\t[22 9 0 0 20]\n",
      "\tRow 32560:\t[52 9 15024 0 40]\n"
     ]
    }
   ],
   "source": [
    "import (\n",
    "    \"flag\"\n",
    "    \n",
    "    \"github.com/gomlx/gomlx/examples/adult\"\n",
    ")\n",
    "\n",
    "var (\n",
    "    flagDataDir       = flag.String(\"data\", \"~/work/uci-adult\", \"Directory to save and load downloaded and generated dataset files.\")\n",
    "    flagVerbosity     = flag.Int(\"verbosity\", 0, \"Level of verbosity, the higher the more verbose.\")\n",
    "    flagForceDownload = flag.Bool(\"force_download\", false, \"Force re-download of Adult dataset files.\")\n",
    "    flagNumQuantiles  = flag.Int(\"quantiles\", 100, \"Max number of quantiles to use for numeric features, used during piece-wise linear calibration. It will only use unique values, so if there are fewer variability, fewer quantiles are used.\")\n",
    ")\n",
    "\n",
    "%% --verbosity=2\n",
    "adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fdab7255-2224-496a-9ec7-75927150f9d7",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 170M\n",
      "-rw-r--r-- 1 janpf janpf 3.8M Jul 20  2024 adult.data\n",
      "-rw-r--r-- 1 janpf janpf 2.0M Jul 20  2024 adult.test\n",
      "-rw-r--r-- 1 janpf janpf 1.3M Jul 20  2024 adult_data-100_quantiles.bin\n",
      "drwxr-x--- 2 janpf janpf 4.0K Oct 11 11:12 base_model\n",
      "drwxr-xr-x 2 janpf janpf 4.0K Jun  4  2009 cifar-10-batches-bin\n",
      "-rw-r--r-- 1 janpf janpf 163M Aug 15  2024 cifar-10-binary.tar.gz\n",
      "drwxr-x--- 2 janpf janpf 4.0K Aug  2  2024 fnn\n",
      "drwxr-x--- 2 janpf janpf 4.0K Jul 21  2024 kan_baseline\n",
      "drwxr-x--- 2 janpf janpf 4.0K Oct 11 11:12 kan_model\n",
      "drwxr-x--- 2 janpf janpf 4.0K Feb 23  2025 test\n",
      "drwxr-x--- 2 janpf janpf 4.0K Jun 18 05:14 test_model\n"
     ]
    }
   ],
   "source": [
    "!ls -lh ~/work/uci-adult"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f1cd493-a81f-4c5f-9004-d21e95103f8e",
   "metadata": {},
   "source": [
    "## Hyperparameters\n",
    "\n",
    "This sets the the superset of hyperparameters with their default values that can be used by the model, by setting them in the context.\n",
    "\n",
    "See the [`demo/main.go`](https://github.com/gomlx/gomlx/blob/main/examples/adult/demo/main.go) file for how to add a flag to allow setting them from the command line."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "31d03e06-f375-4a68-8705-c9d3b6af81ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tbatch_size=17\n",
      "\tfnn_num_hidden_layers=12\n"
     ]
    }
   ],
   "source": [
    "import (\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/context\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/layers\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/layers/fnn\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/layers/kan\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/layers/regularizers\"\n",
    "    \"github.com/gomlx/gomlx/ui/commandline\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/train/optimizers\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/train/optimizers/cosineschedule\"\n",
    "\t\"github.com/janpfeifer/must\"    \n",
    ")\n",
    "\n",
    "// settings is bound to a \"-set\" flag to be used to set context hyperparameters.\n",
    "var settings = commandline.CreateContextSettingsFlag(createDefaultContext(), \"set\")\n",
    "\n",
    "func createDefaultContext() *context.Context {\n",
    "\tctx := context.New()\n",
    "\tctx.RngStateReset()\n",
    "\tctx.SetParams(map[string]any{\n",
    "\t\t\"train_steps\":     5000,\n",
    "\t\t\"batch_size\":      128,\n",
    "\t\t\"eval_batch_size\":      1000,\n",
    "\t\t\"plots\":           true,\n",
    "        \"num_checkpoints\": 3,\n",
    "\n",
    "\t\toptimizers.ParamOptimizer:           \"adam\",\n",
    "\t\toptimizers.ParamLearningRate:        0.001,\n",
    "\t\toptimizers.ParamAdamEpsilon:         1e-7,\n",
    "\t\toptimizers.ParamAdamDType:           \"\",\n",
    "\t\tcosineschedule.ParamPeriodSteps:     0,\n",
    "\t\tactivations.ParamActivation:         \"sigmoid\",\n",
    "\t\tlayers.ParamDropoutRate:             0.0,\n",
    "\t\tregularizers.ParamL2:                1e-5,\n",
    "\t\tregularizers.ParamL1:                1e-5,\n",
    "\n",
    "\t\t// FNN network parameters:\n",
    "\t\tfnn.ParamNumHiddenLayers: 1,\n",
    "\t\tfnn.ParamNumHiddenNodes:  4,\n",
    "\t\tfnn.ParamResidual:        true,\n",
    "\t\tfnn.ParamNormalization:   \"layer\",\n",
    "\n",
    "\t\t// KAN network parameters:\n",
    "\t\t\"kan\":                       false, // Enable kan\n",
    "\t\tkan.ParamNumControlPoints:   20,    // Number of control points\n",
    "\t\tkan.ParamNumHiddenNodes:     4,\n",
    "\t\tkan.ParamNumHiddenLayers:    1,\n",
    "\t\tkan.ParamBSplineDegree:      2,\n",
    "\t\tkan.ParamBSplineMagnitudeL1: 1e-5,\n",
    "\t\tkan.ParamBSplineMagnitudeL2: 0.0,\n",
    "\t\tkan.ParamDiscrete:           false,\n",
    "\t\tkan.ParamDiscreteSoftness:   0.1,\n",
    "\t})\n",
    "\treturn ctx\n",
    "}\n",
    "\n",
    "// contextFromSettings is the default context (createDefaultContext) changed by -set flag.\n",
    "func contextFromSettings() (ctx *context.Context, paramsSet []string) {\n",
    "    ctx = createDefaultContext()\n",
    "    paramsSet = must.M1(commandline.ParseContextSettings(ctx, *settings))\n",
    "    return\n",
    "}\n",
    "\n",
    "// Let's test that we can set hyperparameters by setting it in the \"-set\" flag:\n",
    "%% -set=\"batch_size=17;fnn_num_hidden_layers=12\"\n",
    "ctx, paramsSet := contextFromSettings()\n",
    "for _, name := range paramsSet {\n",
    "    if value, found := ctx.GetParam(name); found {\n",
    "        fmt.Printf(\"\\t%s=%v\\n\", name, value)\n",
    "    }\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "028b783b-56f0-457c-b28e-7bb0dd7e94da",
   "metadata": {},
   "source": [
    "### Creating Datasets\n",
    "\n",
    "First we create the GoMLX's `backend`: it's the engine instance (XLA in this case) that \n",
    "compiles and executes our computation graph.\n",
    "\n",
    "It's needed to create tensors that will be fed to the accelearator (GPU or even CPU accelerated code)\n",
    "\n",
    "With that we create the samplers of data that we will use to train and evaluate. They implement \n",
    "GoMLX's `train.Dataset` interface, which is what is used by our training loop to draw batches to\n",
    "train, or our eval loop to draw batches to evaluate.\n",
    "\n",
    "The inputs are 3 tensors: *categorical values*, *continuous values* and *weights*.\n",
    "\n",
    "In the cell below we define `BuildDatasets` and printout some samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f13f771b-7109-49d2-8db5-5f3d5a91d324",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Backend: stablehlo - stablehlo:cuda - PJRT \"cuda\" plugin (/home/janpf/.local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.76 [StableHLO]\n",
      "Inputs of batch (size 128):\n",
      "\tcategorical:\n",
      "\t\tFeatures=[workclass education marital-status occupation relationship race sex native-country]\n",
      "\t\tValues: [128][8]int64{\n",
      " {4, 6, 5, ..., 5, 2, 39},\n",
      " {4, 16, 3, ..., 5, 1, 39},\n",
      " {4, 12, 3, ..., 5, 1, 39},\n",
      " ...,\n",
      " {4, 16, 3, ..., 3, 2, 23},\n",
      " {4, 10, 3, ..., 5, 2, 39},\n",
      " {4, 12, 5, ..., 5, 1, 39}}\n",
      "\tcontinuous:\n",
      "\t\tFeatures=[age education-num capital-gain capital-loss hours-per-week]\n",
      "\t\tValues: [128][5]float32{\n",
      " {60, 4, 0, 0, 48},\n",
      " {53, 10, 0, 0, 40},\n",
      " {26, 9, 5013, 0, 40},\n",
      " ...,\n",
      " {34, 10, 0, 0, 52},\n",
      " {38, 13, 0, 0, 40},\n",
      " {30, 9, 0, 0, 40}}\n",
      "\tweights: [128][1]float32{\n",
      " {3.935e+04},\n",
      " {3.468e+05},\n",
      " {1.327e+05},\n",
      " ...,\n",
      " {1.898e+05},\n",
      " {6.078e+05},\n",
      " {1.678e+05}}\n",
      "\n",
      "Labels of batch:\n",
      "\t[128][1]float32{\n",
      " {1},\n",
      " {1},\n",
      " {0},\n",
      " ...,\n",
      " {0},\n",
      " {1},\n",
      " {0}}\n",
      "\n",
      "Labels distributions:\n",
      "\tTrain:\t24.08% positive\n",
      "\tTest:\t23.62% positive\n"
     ]
    }
   ],
   "source": [
    "import (\n",
    "    \"flag\"\n",
    "    \"fmt\"\n",
    "    \"io\"\n",
    "\n",
    "    \"github.com/gomlx/gomlx/backends\"\n",
    "    \"github.com/gomlx/gomlx/examples/adult\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/train\"\n",
    "    \"github.com/gomlx/gomlx/pkg/core/tensors\"\n",
    "\n",
    "    _ \"github.com/gomlx/gomlx/backends/default\"\n",
    ")\n",
    "\n",
    "// Global backend created an initialization, used everywhere.\n",
    "var backend = backends.MustNew()\n",
    "\n",
    "// BuildDatasets returns 3 `train.Dataset`:\n",
    "// * trainingSampler is an endless random sampler used for training.\n",
    "// * trainingEvalSampler samples through exactly one epoch of the train dataset.\n",
    "// * testEvalSampler samples through exactly one epoch of the test dataset.\n",
    "func BuildDatasets(ctx *context.Context) (trainDS, trainEvalDS, testEvalDS train.Dataset) {\n",
    "\tbatchSize := context.GetParamOr(ctx, \"batch_size\", 128)\n",
    "\tevalBatchSize := context.GetParamOr(ctx, \"eval_batch_size\", 1000)\n",
    "    baseDS := adult.NewDataset(backend, adult.Data.Train, \"batched train\")\n",
    "    trainEvalDS = baseDS.Copy().BatchSize(evalBatchSize, false)\n",
    "    testEvalDS = adult.NewDataset(backend, adult.Data.Test, \"test\").\n",
    "        BatchSize(evalBatchSize, false)\n",
    "    // For training, we shuffle and loop indefinitely.\n",
    "    trainDS = baseDS.BatchSize(batchSize, true).Shuffle().Infinite(true)\n",
    "    return\n",
    "}\n",
    "\n",
    "// PositiveRatio finds out the the ratio of positive labels in the\n",
    "// training and testing data.\n",
    "//\n",
    "// We could do this easily with GoMLX computation model (just `ReduceAllSum`), but\n",
    "// this examples shows it's also ok to mix Go computations.\n",
    "func PositiveRatio(ds train.Dataset) float32 {\n",
    "    ds.Reset()  // Start from beginning.\n",
    "    var sum float32\n",
    "    var count float32\n",
    "    for {\n",
    "        _, _, labels, err := ds.Yield()\n",
    "        if err == io.EOF {\n",
    "            break;\n",
    "        }\n",
    "        if err != nil { panic(err) }\n",
    "        data := tensors.CopyFlatData[float32](labels[0])\n",
    "        for _, value := range data {\n",
    "            sum += value\n",
    "        }\n",
    "        count += float32(len(data))\n",
    "    }\n",
    "    return sum/count\n",
    "}\n",
    "\n",
    "%%\n",
    "fmt.Printf(\"Backend: %s - %s\\n\", backend, backend.Description())\n",
    "\n",
    "adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)    \n",
    "ctx, _ := contextFromSettings()\n",
    "trainingDS, trainingEvalDS, testEvalDS := BuildDatasets(ctx)\n",
    "\n",
    "// Take one batch.\n",
    "_, inputs, labels, err := trainingDS.Yield()\n",
    "if err != nil { panic(err) }\n",
    "fmt.Printf(\"Inputs of batch (size %d):\\n\", context.GetParamOr(ctx, \"batch_size\", 0))\n",
    "fmt.Printf(\"\\tcategorical:\\n\\t\\tFeatures=%v\\n\", adult.Data.VocabulariesFeatures)\n",
    "fmt.Printf(\"\\t\\tValues: %s\\n\", inputs[0])\n",
    "fmt.Printf(\"\\tcontinuous:\\n\\t\\tFeatures=%v\\n\", adult.Data.QuantilesFeatures)\n",
    "fmt.Printf(\"\\t\\tValues: %s\\n\", inputs[1])\n",
    "fmt.Printf(\"\\tweights: %s\\n\", inputs[2])\n",
    "fmt.Printf(\"\\nLabels of batch:\\n\\t%s\\n\", labels[0])\n",
    "fmt.Printf(\"\\nLabels distributions:\\n\\tTrain:\\t%.2f%% positive\\n\\tTest:\\t%.2f%% positive\\n\",\n",
    "           PositiveRatio(trainingEvalDS)*100.0, PositiveRatio(testEvalDS)*100.0)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54efd046-bbd5-4f7d-94e0-9ed3c64512b1",
   "metadata": {},
   "source": [
    "## Model Definition\n",
    "\n",
    "Lots of hyper-parameter flags, but otherwise a straight forward FNN, using piece-wise linear calibration of the continuous features, and embeddings for the categorical features.\n",
    "\n",
 **Note**: building models">
    "> **Note**: building models involves constantly checking that shapes are compatible. It's a bit annoying, in particular because shapes are only known at runtime, so there is no compile-time check. GoMLX tries to help by providing a stack trace of where errors happen, so one can pinpoint issues quickly. But it often involves lots of experimentation (more than ordinary Go code).\n",
    ">\n",
    "> Developing with a Notebook (see [GoNB](https://github.com/janpfeifer/gonb)) or simply a unit test on your\n",
    "> `ModelGraph` function are quick and convenient ways to develop models -- before actually training them.\n",
    "> You can also use shape asserts in the middle of the `ModelGraph`, as we do below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7040e63b-33f1-4ff9-a232-187196c5f9c1",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Logits shape for batch_size=128: (Float32)[128 1]\n"
     ]
    }
   ],
   "source": [
    "import (\n",
    "    \"fmt\"\n",
    "    \"io\"\n",
    "\n",
    "    . \"github.com/gomlx/gomlx/pkg/core/graph\"\n",
    "\n",
    "    \"github.com/gomlx/gomlx/examples/adult\"\n",
    "    \"github.com/gomlx/gomlx/pkg/core/shapes\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/context\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/train\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/train/optimizers\"\n",
    "    \"github.com/gomlx/gomlx/pkg/ml/train/optimizers/cosineschedule\"\n",
    "    \"github.com/gomlx/gopjrt/dtypes\"\n",
    ")\n",
    "\n",
    "var (\n",
    "    // ModelDType used for the model. Must match RawData Go types.\n",
    "    ModelDType = dtypes.Float32\n",
    "    \n",
    "\n",
    "    // Model hyperparameters.\n",
    "    flagUseCategorical       = flag.Bool(\"use_categorical\", true, \"Use categorical features.\")\n",
    "    flagUseContinuous        = flag.Bool(\"use_continuous\", true, \"Use continuous features.\")\n",
    "    flagTrainableCalibration = flag.Bool(\"trainable_calibration\", true, \"Allow piece-wise linear calibration to adjust outputs.\")\n",
    "    flagEmbeddingDim    = flag.Int(\"embedding_dim\", 8, \"Default embedding dimension for categorical values.\")\n",
    ")\n",
    "\n",
    "\n",
    "// ModelGraph outputs the logits (not the probabilities). The parameter inputs should contain 3 tensors:\n",
    "//\n",
    "// - categorical inputs, shaped  `(int64)[batch_size, len(VocabulariesFeatures)]`\n",
    "// - continuous inputs, shaped `(float32)[batch_size, len(Quantiles)]`\n",
    "// - weights: not currently used, but shaped `(float32)[batch_size, 1]`.\n",
    "func ModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {\n",
    "\t_ = spec // Not used, since the dataset is always the same.\n",
    "\tg := inputs[0].Graph()\n",
    "\tdtype := inputs[1].DType() // From continuous features.\n",
    "    ctx = ctx.In(\"model\")\n",
    "    \n",
    "\t// Use Cosine schedule of the learning rate, if hyperparameter is set to a value > 0.\n",
    "\tcosineschedule.New(ctx, g, dtype).FromContext().Done()\n",
    "\n",
    "\tcategorical, continuous := inputs[0], inputs[1]\n",
    "\tbatchSize := categorical.Shape().Dimensions[0]\n",
    "\n",
    "\t// Feature preprocessing:\n",
    "\tvar allEmbeddings []*Node\n",
    "\tif *flagUseCategorical {\n",
    "\t\t// Embedding of categorical values, each with its own vocabulary.\n",
    "\t\tnumCategorical := categorical.Shape().Dimensions[1]\n",
    "\t\tfor catIdx := 0; catIdx < numCategorical; catIdx++ {\n",
    "\t\t\t// Take one column at a time of the categorical values.\n",
    "\t\t\tsplit := Slice(categorical, AxisRange(), AxisRange(catIdx, catIdx+1))\n",
    "\t\t\t// Embed it accordingly.\n",
    "\t\t\tembedCtx := ctx.In(fmt.Sprintf(\"categorical_%d_%s\", catIdx, adult.Data.VocabulariesFeatures[catIdx]))\n",
    "\t\t\tvocab := adult.Data.Vocabularies[catIdx]\n",
    "\t\t\tvocabSize := len(vocab)\n",
    "\t\t\tembedding := layers.Embedding(embedCtx, split, ModelDType, vocabSize, *flagEmbeddingDim)\n",
    "\t\t\tembedding.AssertDims(batchSize, *flagEmbeddingDim) // 2-dim tensor, with batch size as the leading dimension.\n",
    "\t\t\tallEmbeddings = append(allEmbeddings, embedding)\n",
    "\t\t}\n",
    "\t}\n",
    "\n",
    "\tif *flagUseContinuous {\n",
    "\t\t// Piecewise-linear calibration of the continuous values. Each feature has its own number of quantiles.\n",
    "\t\tnumContinuous := continuous.Shape().Dimensions[1]\n",
    "\t\tfor contIdx := 0; contIdx < numContinuous; contIdx++ {\n",
    "\t\t\t// Take one column at a time of the continuous values.\n",
    "\t\t\tsplit := Slice(continuous, AxisRange(), AxisRange(contIdx, contIdx+1))\n",
    "\t\t\tfeatureName := adult.Data.QuantilesFeatures[contIdx]\n",
    "\t\t\tcalibrationCtx := ctx.In(fmt.Sprintf(\"continuous_%d_%s\", contIdx, featureName))\n",
    "\t\t\tquantiles := adult.Data.Quantiles[contIdx]\n",
    "\t\t\tlayers.AssertQuantilesForPWLCalibrationValid(quantiles)\n",
    "\t\t\tcalibrated := layers.PieceWiseLinearCalibration(calibrationCtx, split, Const(g, quantiles),\n",
    "\t\t\t\t*flagTrainableCalibration)\n",
    "\t\t\tcalibrated.AssertDims(batchSize, 1) // 2-dim tensor, with batch size as the leading dimension.\n",
    "\t\t\tallEmbeddings = append(allEmbeddings, calibrated)\n",
    "\t\t}\n",
    "\t}\n",
    "\tlogits := Concatenate(allEmbeddings, -1)\n",
    "\tlogits.AssertDims(batchSize, -1) // 2-dim tensor, with batch size as the leading dimension (-1 means it is not checked).\n",
    "\n",
    "\t// Model itself is an FNN or a KAN.\n",
    "\tif context.GetParamOr(ctx, \"kan\", false) {\n",
    "\t\t// Use KAN, all configured by context hyperparameters. See createDefaultContext for defaults.\n",
    "\t\tlogits = kan.New(ctx.In(\"kan\"), logits, 1).Done()\n",
    "\t} else {\n",
    "\t\t// Normal FNN, all configured by context hyperparameters. See createDefaultContext for defaults.\n",
    "\t\tlogits = fnn.New(ctx.In(\"fnn\"), logits, 1).Done()\n",
    "\t}\n",
    "\tlogits.AssertDims(batchSize, 1) // 2-dim tensor, with batch size as the leading dimension.\n",
    "\treturn []*Node{logits}\n",
    "}\n",
    "\n",
    "%%\n",
    "adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)    \n",
    "\n",
    "// Let's just check that we get the right shape from the model function, wihtout any real data.\n",
    "ctx, _ := contextFromSettings()\n",
    "graph := NewGraph(backend, \"test\")\n",
    "batchSize := context.GetParamOr(ctx, \"batch_size\", 128)\n",
    "// Create placeholder (parameters) graph nodes, just to test the graph building is working.\n",
    "inputs := []*Node{\n",
    "    // Categorical: shaped [batch_size, num_categorical]\n",
    "    Parameter(graph, \"categorical\", shapes.Make(dtypes.Int64, batchSize, len(adult.Data.VocabulariesFeatures))),\n",
    "    // Continuous: shaped [batch_size, num_continuos]\n",
    "    Parameter(graph, \"continuous\", shapes.Make(dtypes.Float32, batchSize, len(adult.Data.QuantilesFeatures))),\n",
    "    // Weights: shaped [batch_size, 1]\n",
    "    Parameter(graph, \"weights\", shapes.Make(dtypes.Float32, batchSize, 1)),    \n",
    "}\n",
    "logits := ModelGraph(ctx, nil, inputs)\n",
    "fmt.Printf(\"Logits shape for batch_size=%d: %s\\n\", batchSize, logits[0].Shape())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "618e1c7a-ced1-4914-8530-0924e02f095c",
   "metadata": {},
   "source": [
    "## Training Loop\n",
    "\n",
    "We can create a training loop with only a `Manager`, a `Context` (for the model varibles) and the `ModelGraph` function.\n",
    "\n",
    "To make it more interesting we also add the following:\n",
    "\n",
    "* Accuracy metrics for training and testing.\n",
    "* Checkpoints -- so trained model can be saved, and reloaded.\n",
    "* A progress-bar that also shows training metrics.\n",
    "* We dynamically plot how the loss and accuracy evolve.\n",
    "\n",
    "First we define the corresponding flags and the `trainModel` function, and run it for very few steps to make sure\n",
    "it is working."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "daf39a7c-7cc8-43a5-ae79-026303f5f502",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      \u001b[1m 100% [========================================] (223 steps/s)\u001b[0m [step=99] [loss+=0.475] [~loss+=0.561] [~loss=0.56] [~acc=72.50%]        ]        \n",
      "\t[Step 100] median train step: 2561 microseconds\n",
      "Results on batched train:\n",
      "\tMean Loss+Regularization (#loss+): 0.489\n",
      "\tMean Loss (#loss): 0.488\n",
      "\tMean Accuracy (#acc): 76.36%\n",
      "Results on test:\n",
      "\tMean Loss+Regularization (#loss+): 0.483\n",
      "\tMean Loss (#loss): 0.483\n",
      "\tMean Accuracy (#acc): 76.91%\n"
     ]
    }
   ],
   "source": [
    "import (\n",
    "    \"fmt\"\n",
    "    \"time\"\n",
    "    \"github.com/gomlx/gomlx/ui/gonb/plotly\"\n",
    "    \"github.com/gomlx/gomlx/pkg/support/fsutil\"\n",
    ")\n",
    "\n",
    "var (\n",
    "    flagCheckpoint     = flag.String(\"checkpoint\", \"\", \"Directory save and load checkpoints from. If left empty, no checkpoints are created.\")\n",
    "    flagPlots          = flag.Bool(\"plots\", true, \"Plots during training: perform periodic evaluations, \"+\n",
    "                                   \"save results if --checkpoint is set and draw plots, if in a Jupyter notebook.\")\n",
    "    flagPlotType       = flag.String(\"plot_type\", \"plotly\", \"Type of plot to use, values are \\\"plotly\\\" or \\\"margaid\\\"\")\n",
    ")\n",
    "\n",
    "func trainModel(ctx *context.Context, paramsSet []string) {\n",
    "    *flagDataDir = must.M1(fsutil.ReplaceTildeInDir(*flagDataDir))\n",
    "\n",
    "    // Load data and create datasets.\n",
    "    adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)    \n",
    "    trainDS, trainEvalDS, testEvalDS := BuildDatasets(ctx)\n",
    "\n",
    "\t// Checkpoints loading (and saving)\n",
    "\tvar checkpoint *checkpoints.Handler\n",
    "\tif *flagCheckpoint != \"\" {\n",
    "\t\tnumCheckpointsToKeep := context.GetParamOr(ctx, \"num_checkpoints\", 3)\n",
    "\t\tcheckpoint = must.M1(checkpoints.Build(ctx).\n",
    "\t\t\tDirFromBase(*flagCheckpoint, *flagDataDir).\n",
    "\t\t\tKeep(numCheckpointsToKeep).\n",
    "\t\t\tExcludeParams(append(paramsSet, \"train_steps\", \"plots\", \"num_checkpoints\")...).\n",
    "\t\t\tDone())\n",
    "\t}\n",
    "\n",
    "\t// Metrics we are interested.\n",
    "\tmeanAccuracyMetric := metrics.NewMeanBinaryLogitsAccuracy(\"Mean Accuracy\", \"#acc\")\n",
    "\tmovingAccuracyMetric := metrics.NewMovingAverageBinaryLogitsAccuracy(\"Moving Average Accuracy\", \"~acc\", 0.01)\n",
    "\n",
    "\t// Create a train.Trainer: this object will orchestrate running the model, feeding\n",
    "\t// results to the optimizer, evaluating the metrics, etc. (all happens in trainer.TrainStep)\n",
    "\ttrainer := train.NewTrainer(backend, ctx, ModelGraph, losses.BinaryCrossentropyLogits,\n",
    "\t\toptimizers.FromContext(ctx),\n",
    "\t\t[]metrics.Interface{movingAccuracyMetric}, // trainMetrics\n",
    "\t\t[]metrics.Interface{meanAccuracyMetric})   // evalMetrics\n",
    "\n",
    "    // Use standard training loop.\n",
    "    loop := train.NewLoop(trainer)\n",
    "    commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.\n",
    "\n",
    "    // Attach a checkpoint saver: checkpoint every 1 minute of training.\n",
    "    if checkpoint != nil {\n",
    "        period := time.Minute * 1\n",
    "        train.PeriodicCallback(loop, period, true, \"saving checkpoint\", 100,\n",
    "            func(loop *train.Loop, metrics []*tensors.Tensor) error {\n",
    "                fmt.Printf(\"\\n[saving checkpoint@%d] [median train step (ms): %d]\\n\", loop.LoopStep, loop.MedianTrainStepDuration().Milliseconds())\n",
    "                return checkpoint.Save()\n",
    "            })\n",
    "    }\n",
    "\n",
    "\t// Attach Plotly plots: plot points at exponential steps.\n",
    "\t// The points generated are saved along the checkpoint directory (if one is given).\n",
    "\tif context.GetParamOr(ctx, margaid.ParamPlots, false) {\n",
    "\t\t_ = plotly.New().\n",
    "\t\t\tWithCheckpoint(checkpoint).\n",
    "\t\t\tDynamic().\n",
    "\t\t\tWithDatasets(trainEvalDS, testEvalDS).\n",
    "\t\t\tScheduleExponential(loop, 200, 1.2)\n",
    "\t}\n",
    "\n",
    "\t// Train up to \"train_steps\".\n",
    "\tglobalStep := int(optimizers.GetGlobalStep(ctx))\n",
    "\ttrainSteps := context.GetParamOr(ctx, \"train_steps\", 0)\n",
    "\tif globalStep < trainSteps {\n",
    "\t\tif globalStep != 0 {\n",
    "\t\t\tfmt.Printf(\"\\t- restarting training from global_step=%d\\n\", globalStep)\n",
    "            trainer.SetContext(ctx.Reuse())\n",
    "\t\t}\n",
    "\t\t_ = must.M1(loop.RunSteps(trainDS, trainSteps-globalStep))\n",
    "\t\tfmt.Printf(\"\\t[Step %d] median train step: %d microseconds\\n\", loop.LoopStep, loop.MedianTrainStepDuration().Microseconds())\n",
    "\t} else {\n",
    "\t\tfmt.Printf(\"\\t - target train_steps=%d already reached. To train further, set a number larger than \"+\n",
    "\t\t\t\"current global step.\\n\", trainSteps)\n",
    "\t}\n",
    "\n",
    "\t// Finally, print an evaluation on train and test datasets.\n",
    "\tmust.M(commandline.ReportEval(trainer, trainEvalDS, testEvalDS))\n",
    "}\n",
    "\n",
    "// Notice command line flags are passed in the %% notebook command. We set plots=false here to disable plotting\n",
    "// since this is only a quick test that our train() loop is working. See below the final run for a full training.\n",
    "%% -set=\"train_steps=100;plots=false\"\n",
    "trainModel(contextFromSettings())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bb5db6b-3a2c-4f48-8264-f90b22a52ed0",
   "metadata": {},
   "source": [
    "## Final run with 5K steps\n",
    "\n",
    "With everything working, we can do our final run.\n",
    "\n",
    "> **Note** here is where someone might want to hyperparameter tune, trying out different hyperparameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a9d901d9-c503-4a94-8dec-a471a18bbbee",
   "metadata": {},
   "outputs": [],
   "source": [
    "// Remove previously trained model -- skip this cell, if you want to continue training.\n",
    "!rm -rf ~/work/uci-adult/base_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "66939463-60d7-446f-b4f3-1d60abe7c17c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      \u001b[1m  14% [====>...................................] (172 steps/s) [5s:24s]\u001b[0m [step=724] [loss+=0.255] [~loss+=0.318] [~loss=0.317] [~acc=85.61%]        "
     ]
    },
    {
     "data": {
      "text/html": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      \u001b[1m 100% [========================================] (182 steps/s)\u001b[0m [step=4999] [loss+=0.252] [~loss+=0.273] [~loss=0.273] [~acc=87.14%]        ]         \n",
      "\n",
      "[saving checkpoint@5000] [median train step (ms): 2]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<p><b>Metric: accuracy</b></p>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div id=\"8801fab9\"></div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<script charset=\"UTF-8\">\n",
       "(() => {\n",
       "\tconst src=\"https://cdn.plot.ly/plotly-2.34.0.min.js\";\n",
       "\tvar runJSFn = function(module) {\n",
       "\t\t\n",
       "\tif (!module) {\n",
       "\t\tmodule = window.Plotly;\n",
       "\t}\n",
       "\tlet data = JSON.parse('{\"data\":[{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average Accuracy\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000],\"y\":[0.7928646206855774,0.8410997986793518,0.854864239692688,0.8638507127761841,0.8692387938499451,0.8704817295074463,0.8728450536727905,0.8718516826629639,0.8750863075256348,0.8713744878768921]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Accuracy on batched train\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000],\"y\":[0.8305948972702026,0.8540278077125549,0.8618285655975342,0.8660053610801697,0.8703663945198059,0.8719019889831543,0.8711649179458618,0.8730690479278564,0.873007595539093,0.8747274279594421]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Accuracy on test\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000],\"y\":[0.8299857974052429,0.8549228310585022,0.8610649704933167,0.8611878156661987,0.8684969544410706,0.8686197996139526,0.867084264755249,0.8696639537811279,0.8686811923980713,0.8718751072883606]}],\"layout\":{\"legend\":{},\"title\":{\"text\":\"accuracy\"},\"xaxis\":{\"showgrid\":true,\"title\":{\"text\":\"Steps\"},\"type\":\"log\"},\"yaxis\":{\"showgrid\":true,\"type\":\"log\"}}}');\n",
       "\tmodule.newPlot('8801fab9', data);\n",
       "\n",
       "\t}\n",
       "\t\n",
       "    if (typeof requirejs === \"function\") {\n",
       "        // Use RequireJS to load module.\n",
       "\t\tlet srcWithoutExtension = src.substring(0, src.lastIndexOf(\".js\"));\n",
       "        requirejs.config({\n",
       "            paths: {\n",
       "                'plotly': srcWithoutExtension\n",
       "            }\n",
       "        });\n",
       "        require(['plotly'], function(plotly) {\n",
       "            runJSFn(plotly)\n",
       "        });\n",
       "        return\n",
       "    }\n",
       "\n",
       "\tvar currentScripts = document.head.getElementsByTagName(\"script\");\n",
       "\tfor (const idx in currentScripts) {\n",
       "\t\tlet script = currentScripts[idx];\n",
       "\t\tif (script.src == src) {\n",
       "\t\t\trunJSFn(null);\n",
       "\t\t\treturn;\n",
       "\t\t}\n",
       "\t}\n",
       "\n",
       "\tvar script = document.createElement(\"script\");\n",
       "\n",
       "\tscript.charset = \"utf-8\";\n",
       "\t\n",
       "\tscript.src = src;\n",
       "\tscript.onload = script.onreadystatechange = function () { runJSFn(null); };\n",
       "\tdocument.head.appendChild(script);\t\n",
       "})();\n",
       "</script>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<p><b>Metric: loss</b></p>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div id=\"c24cc180\"></div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<script charset=\"UTF-8\">\n",
       "(() => {\n",
       "\tconst src=\"https://cdn.plot.ly/plotly-2.34.0.min.js\";\n",
       "\tvar runJSFn = function(module) {\n",
       "\t\t\n",
       "\tif (!module) {\n",
       "\t\tmodule = window.Plotly;\n",
       "\t}\n",
       "\tlet data = JSON.parse('{\"data\":[{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Batch Loss+Regularization\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000],\"y\":[0.39764603972435,0.33280327916145325,0.3199615776538849,0.3172374665737152,0.27199599146842957,0.3090786933898926,0.29576244950294495,0.23461459577083588,0.206096813082695,0.2516441345214844]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average Loss+Regularization\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000],\"y\":[0.4513116180896759,0.3578449487686157,0.31888654828071594,0.2991269826889038,0.2893325388431549,0.2828262746334076,0.2786683142185211,0.2747131586074829,0.27006566524505615,0.27325260639190674]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average Loss\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000],\"y\":[0.4507942199707031,0.35732340812683105,0.31836193799972534,0.2985967993736267,0.2887956500053406,0.28228089213371277,0.2781136631965637,0.2741488814353943,0.26949042081832886,0.2726672291755676]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss+Regularization on batched train\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000],\"y\":[0.3978425860404968,0.33474239706993103,0.3067660331726074,0.29345664381980896,0.2842923402786255,0.28020167350769043,0.277349054813385,0.27284327149391174,0.2725278437137604,0.27038607001304626]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss on batched 
train\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000],\"y\":[0.39732229709625244,0.334220290184021,0.3062390983104706,0.29292553663253784,0.28375405073165894,0.27965566515922546,0.2767936587333679,0.2722783088684082,0.27195170521736145,0.26980045437812805]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss+Regularization on test\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000],\"y\":[0.394499272108078,0.33290988206863403,0.3095090687274933,0.29824915528297424,0.29070383310317993,0.28719133138656616,0.28529781103134155,0.2822219431400299,0.2822258472442627,0.27995169162750244]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss on test\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000],\"y\":[0.39397895336151123,0.3323877453804016,0.3089821934700012,0.29771798849105835,0.29016539454460144,0.2866452932357788,0.28474241495132446,0.28165698051452637,0.2816496789455414,0.2793661057949066]}],\"layout\":{\"legend\":{},\"title\":{\"text\":\"loss\"},\"xaxis\":{\"showgrid\":true,\"title\":{\"text\":\"Steps\"},\"type\":\"log\"},\"yaxis\":{\"showgrid\":true,\"type\":\"log\"}}}');\n",
       "\tmodule.newPlot('c24cc180', data);\n",
       "\n",
       "\t}\n",
       "\t\n",
       "    if (typeof requirejs === \"function\") {\n",
       "        // Use RequireJS to load module.\n",
       "\t\tlet srcWithoutExtension = src.substring(0, src.lastIndexOf(\".js\"));\n",
       "        requirejs.config({\n",
       "            paths: {\n",
       "                'plotly': srcWithoutExtension\n",
       "            }\n",
       "        });\n",
       "        require(['plotly'], function(plotly) {\n",
       "            runJSFn(plotly)\n",
       "        });\n",
       "        return\n",
       "    }\n",
       "\n",
       "\tvar currentScripts = document.head.getElementsByTagName(\"script\");\n",
       "\tfor (const idx in currentScripts) {\n",
       "\t\tlet script = currentScripts[idx];\n",
       "\t\tif (script.src == src) {\n",
       "\t\t\trunJSFn(null);\n",
       "\t\t\treturn;\n",
       "\t\t}\n",
       "\t}\n",
       "\n",
       "\tvar script = document.createElement(\"script\");\n",
       "\n",
       "\tscript.charset = \"utf-8\";\n",
       "\t\n",
       "\tscript.src = src;\n",
       "\tscript.onload = script.onreadystatechange = function () { runJSFn(null); };\n",
       "\tdocument.head.appendChild(script);\t\n",
       "})();\n",
       "</script>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t[Step 5000] median train step: 2600 microseconds\n",
      "Results on batched train:\n",
      "\tMean Loss+Regularization (#loss+): 0.27\n",
      "\tMean Loss (#loss): 0.27\n",
      "\tMean Accuracy (#acc): 87.47%\n",
      "Results on test:\n",
      "\tMean Loss+Regularization (#loss+): 0.28\n",
      "\tMean Loss (#loss): 0.279\n",
      "\tMean Accuracy (#acc): 87.19%\n"
     ]
    }
   ],
   "source": [
    "%% --checkpoint base_model -set=\"plots=true;train_steps=5000\"\n",
    "trainModel(contextFromSettings())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c38d8a41-3e9c-4107-a507-c879ef27e3e3",
   "metadata": {},
   "source": [
    "## Extend training another 5K steps\n",
    "\n",
    "Since the model training went well, and it doesn't seem to be yet terribly overfiting, \n",
    "let's train further, another 5k steps, for 10K steps in total.\n",
    "\n",
    "Notice the plots continue from where it stopped. And this time we use [Plotly](https://plotly.com/javascript/) to plot the training results -- they don't display in Github since they depend on javascript.\n",
    "\n",
    "Unfortunately, it doesn't help (the accuracy on the test set doesn't improve), 5k steps was already enough."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "95983f61-2a98-442d-bc5d-896006289a2b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_steps=10000\n"
     ]
    },
    {
     "data": {
      "text/html": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t- restarting training from global_step=5000\n",
      "      \u001b[1m 100% [========================================] (186 steps/s)\u001b[0m [step=9999] [loss+=0.229] [~loss+=0.268] [~loss=0.268] [~acc=87.45%]        ]         \n",
      "\n",
      "[saving checkpoint@10000] [median train step (ms): 2]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<p><b>Metric: accuracy</b></p>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div id=\"e773b05e\"></div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<script charset=\"UTF-8\">\n",
       "(() => {\n",
       "\tconst src=\"https://cdn.plot.ly/plotly-2.34.0.min.js\";\n",
       "\tvar runJSFn = function(module) {\n",
       "\t\t\n",
       "\tif (!module) {\n",
       "\t\tmodule = window.Plotly;\n",
       "\t}\n",
       "\tlet data = JSON.parse('{\"data\":[{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average Accuracy\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000,5200,6441,7930,9717,10000],\"y\":[0.7928646206855774,0.8410997986793518,0.854864239692688,0.8638507127761841,0.8692387938499451,0.8704817295074463,0.8728450536727905,0.8718516826629639,0.8750863075256348,0.8713744878768921,0.873005747795105,0.8748104572296143,0.8744657039642334,0.8740350604057312,0.8744916915893555]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Accuracy on batched train\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000,5200,6441,7930,9717,10000],\"y\":[0.8305948972702026,0.8540278077125549,0.8618285655975342,0.8660053610801697,0.8703663945198059,0.8719019889831543,0.8711649179458618,0.8730690479278564,0.873007595539093,0.8747274279594421,0.8728233575820923,0.8741132020950317,0.8746967315673828,0.8742974996566772,0.8743896484375]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Accuracy on test\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000,5200,6441,7930,9717,10000],\"y\":[0.8299857974052429,0.8549228310585022,0.8610649704933167,0.8611878156661987,0.8684969544410706,0.8686197996139526,0.867084264755249,0.8696639537811279,0.8686811923980713,0.8718751072883606,0.8702781796455383,0.8715680241584778,0.8722436428070068,0.8717522621154785,0.8711994886398315]}],\"layout\":{\"legend\":{},\"title\":{\"text\":\"accuracy\"},\"xaxis\":{\"showgrid\":true,\"title\":{\"text\":\"Steps\"},\"type\":\"log\"},\"yaxis\":{\"showgrid\":true,\"type\":\"log\"}}}');\n",
       "\tmodule.newPlot('e773b05e', data);\n",
       "\n",
       "\t}\n",
       "\t\n",
       "    if (typeof requirejs === \"function\") {\n",
       "        // Use RequireJS to load module.\n",
       "\t\tlet srcWithoutExtension = src.substring(0, src.lastIndexOf(\".js\"));\n",
       "        requirejs.config({\n",
       "            paths: {\n",
       "                'plotly': srcWithoutExtension\n",
       "            }\n",
       "        });\n",
       "        require(['plotly'], function(plotly) {\n",
       "            runJSFn(plotly)\n",
       "        });\n",
       "        return\n",
       "    }\n",
       "\n",
       "\tvar currentScripts = document.head.getElementsByTagName(\"script\");\n",
       "\tfor (const idx in currentScripts) {\n",
       "\t\tlet script = currentScripts[idx];\n",
       "\t\tif (script.src == src) {\n",
       "\t\t\trunJSFn(null);\n",
       "\t\t\treturn;\n",
       "\t\t}\n",
       "\t}\n",
       "\n",
       "\tvar script = document.createElement(\"script\");\n",
       "\n",
       "\tscript.charset = \"utf-8\";\n",
       "\t\n",
       "\tscript.src = src;\n",
       "\tscript.onload = script.onreadystatechange = function () { runJSFn(null); };\n",
       "\tdocument.head.appendChild(script);\t\n",
       "})();\n",
       "</script>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<p><b>Metric: loss</b></p>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div id=\"3edc484f\"></div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<script charset=\"UTF-8\">\n",
       "(() => {\n",
       "\tconst src=\"https://cdn.plot.ly/plotly-2.34.0.min.js\";\n",
       "\tvar runJSFn = function(module) {\n",
       "\t\t\n",
       "\tif (!module) {\n",
       "\t\tmodule = window.Plotly;\n",
       "\t}\n",
       "\tlet data = JSON.parse('{\"data\":[{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Batch Loss+Regularization\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000,5200,6441,7930,9717,10000],\"y\":[0.39764603972435,0.33280327916145325,0.3199615776538849,0.3172374665737152,0.27199599146842957,0.3090786933898926,0.29576244950294495,0.23461459577083588,0.206096813082695,0.2516441345214844,0.2696937918663025,0.19027157127857208,0.2063191831111908,0.26013821363449097,0.22858703136444092]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average Loss+Regularization\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000,5200,6441,7930,9717,10000],\"y\":[0.4513116180896759,0.3578449487686157,0.31888654828071594,0.2991269826889038,0.2893325388431549,0.2828262746334076,0.2786683142185211,0.2747131586074829,0.27006566524505615,0.27325260639190674,0.27277782559394836,0.2677151560783386,0.2697461247444153,0.2689115107059479,0.26822689175605774]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average Loss\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000,5200,6441,7930,9717,10000],\"y\":[0.4507942199707031,0.35732340812683105,0.31836193799972534,0.2985967993736267,0.2887956500053406,0.28228089213371277,0.2781136631965637,0.2741488814353943,0.26949042081832886,0.2726672291755676,0.2721894681453705,0.2671131491661072,0.2691270411014557,0.2682713568210602,0.2675837576389313]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss+Regularization on batched 
train\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000,5200,6441,7930,9717,10000],\"y\":[0.3978425860404968,0.33474239706993103,0.3067660331726074,0.29345664381980896,0.2842923402786255,0.28020167350769043,0.277349054813385,0.27284327149391174,0.2725278437137604,0.27038607001304626,0.27049896121025085,0.26905763149261475,0.26834332942962646,0.2680197060108185,0.2690286934375763]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss on batched train\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000,5200,6441,7930,9717,10000],\"y\":[0.39732229709625244,0.334220290184021,0.3062390983104706,0.29292553663253784,0.28375405073165894,0.27965566515922546,0.2767936587333679,0.2722783088684082,0.27195170521736145,0.26980045437812805,0.2699098289012909,0.2684544324874878,0.2677227556705475,0.2673785090446472,0.2683853805065155]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss+Regularization on test\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000,5200,6441,7930,9717,10000],\"y\":[0.394499272108078,0.33290988206863403,0.3095090687274933,0.29824915528297424,0.29070383310317993,0.28719133138656616,0.28529781103134155,0.2822219431400299,0.2822258472442627,0.27995169162750244,0.28032100200653076,0.2798781991004944,0.28123563528060913,0.28146877884864807,0.2812625765800476]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss on 
test\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5000,5200,6441,7930,9717,10000],\"y\":[0.39397895336151123,0.3323877453804016,0.3089821934700012,0.29771798849105835,0.29016539454460144,0.2866452932357788,0.28474241495132446,0.28165698051452637,0.2816496789455414,0.2793661057949066,0.2797318994998932,0.2792750895023346,0.2806150019168854,0.2808275818824768,0.2806192934513092]}],\"layout\":{\"legend\":{},\"title\":{\"text\":\"loss\"},\"xaxis\":{\"showgrid\":true,\"title\":{\"text\":\"Steps\"},\"type\":\"log\"},\"yaxis\":{\"showgrid\":true,\"type\":\"log\"}}}');\n",
       "\tmodule.newPlot('3edc484f', data);\n",
       "\n",
       "\t}\n",
       "\t\n",
       "    if (typeof requirejs === \"function\") {\n",
       "        // Use RequireJS to load module.\n",
       "\t\tlet srcWithoutExtension = src.substring(0, src.lastIndexOf(\".js\"));\n",
       "        requirejs.config({\n",
       "            paths: {\n",
       "                'plotly': srcWithoutExtension\n",
       "            }\n",
       "        });\n",
       "        require(['plotly'], function(plotly) {\n",
       "            runJSFn(plotly)\n",
       "        });\n",
       "        return\n",
       "    }\n",
       "\n",
       "\tvar currentScripts = document.head.getElementsByTagName(\"script\");\n",
       "\tfor (const idx in currentScripts) {\n",
       "\t\tlet script = currentScripts[idx];\n",
       "\t\tif (script.src == src) {\n",
       "\t\t\trunJSFn(null);\n",
       "\t\t\treturn;\n",
       "\t\t}\n",
       "\t}\n",
       "\n",
       "\tvar script = document.createElement(\"script\");\n",
       "\n",
       "\tscript.charset = \"utf-8\";\n",
       "\t\n",
       "\tscript.src = src;\n",
       "\tscript.onload = script.onreadystatechange = function () { runJSFn(null); };\n",
       "\tdocument.head.appendChild(script);\t\n",
       "})();\n",
       "</script>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t[Step 10000] median train step: 2596 microseconds\n",
      "Results on batched train:\n",
      "\tMean Loss+Regularization (#loss+): 0.269\n",
      "\tMean Loss (#loss): 0.268\n",
      "\tMean Accuracy (#acc): 87.44%\n",
      "Results on test:\n",
      "\tMean Loss+Regularization (#loss+): 0.281\n",
      "\tMean Loss (#loss): 0.281\n",
      "\tMean Accuracy (#acc): 87.12%\n"
     ]
    }
   ],
   "source": [
    "%% --checkpoint base_model -set=\"plots=true;train_steps=10000\"\n",
    "ctx, paramsSet := contextFromSettings()\n",
    "fmt.Printf(\"train_steps=%d\\n\", context.GetParamOr(ctx, \"train_steps\", 0))\n",
    "trainModel(ctx, paramsSet)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06427d06-00d5-4ee5-a408-ddb58621317f",
   "metadata": {},
   "source": [
    "## Using Kolmogorov-Arnold Networks (KAN)\n",
    "\n",
    "Since it's avaialable as a layer (see package `kan`), the model supports it by simply changing a hyperparameter.\n",
    "\n",
    "See description in https://arxiv.org/pdf/2404.19756"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c5d50910-c351-4dda-98cb-5abae8a11ca8",
   "metadata": {},
   "outputs": [],
   "source": [
    "// Remove previously trained model -- skip this cell, if you want to continue training.\n",
    "!rm -rf ~/work/uci-adult/kan_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7ce56c29-0678-4a4f-a028-7a86de57e915",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      \u001b[1m   7% [=>......................................] (132 steps/s) [7s:1m10s]\u001b[0m [step=719] [loss+=4.17] [~loss+=4.28] [~loss=0.455] [~acc=79.83%]        "
     ]
    },
    {
     "data": {
      "text/html": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      \u001b[1m 100% [========================================] (182 steps/s)\u001b[0m [step=9999] [loss+=1.4] [~loss+=1.31] [~loss=0.288] [~acc=87.09%]        %]         \n",
      "\n",
      "[saving checkpoint@10000] [median train step (ms): 2]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<p><b>Metric: accuracy</b></p>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div id=\"735bb102\"></div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<script charset=\"UTF-8\">\n",
       "(() => {\n",
       "\tconst src=\"https://cdn.plot.ly/plotly-2.34.0.min.js\";\n",
       "\tvar runJSFn = function(module) {\n",
       "\t\t\n",
       "\tif (!module) {\n",
       "\t\tmodule = window.Plotly;\n",
       "\t}\n",
       "\tlet data = JSON.parse('{\"data\":[{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average Accuracy\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,10000],\"y\":[0.5356455445289612,0.7363532781600952,0.7976278066635132,0.8205573558807373,0.8305147886276245,0.8370828628540039,0.8442567586898804,0.849527895450592,0.8571056127548218,0.8597533702850342,0.866722822189331,0.8625935912132263,0.8691515326499939,0.8708711862564087]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Accuracy on batched train\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,10000],\"y\":[0.6727373600006104,0.7744541168212891,0.8117072582244873,0.8211356997489929,0.8354166150093079,0.8422038555145264,0.8475784063339233,0.8537207245826721,0.8546420931816101,0.8617671728134155,0.8667117357254028,0.8694143295288086,0.870949923992157,0.8707042336463928]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Accuracy on test\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,10000],\"y\":[0.6820219159126282,0.7751365900039673,0.8074442148208618,0.8211411237716675,0.8333025574684143,0.8378477096557617,0.8449726104736328,0.8501319885253906,0.8492106795310974,0.856826901435852,0.8649959564208984,0.8620477318763733,0.8641974925994873,0.8622934222221375]}],\"layout\":{\"legend\":{},\"title\":{\"text\":\"accuracy\"},\"xaxis\":{\"showgrid\":true,\"title\":{\"text\":\"Steps\"},\"type\":\"log\"},\"yaxis\":{\"showgrid\":true,\"type\":\"log\"}}}');\n",
       "\tmodule.newPlot('735bb102', data);\n",
       "\n",
       "\t}\n",
       "\t\n",
       "    if (typeof requirejs === \"function\") {\n",
       "        // Use RequireJS to load module.\n",
       "\t\tlet srcWithoutExtension = src.substring(0, src.lastIndexOf(\".js\"));\n",
       "        requirejs.config({\n",
       "            paths: {\n",
       "                'plotly': srcWithoutExtension\n",
       "            }\n",
       "        });\n",
       "        require(['plotly'], function(plotly) {\n",
       "            runJSFn(plotly)\n",
       "        });\n",
       "        return\n",
       "    }\n",
       "\n",
       "\tvar currentScripts = document.head.getElementsByTagName(\"script\");\n",
       "\tfor (const idx in currentScripts) {\n",
       "\t\tlet script = currentScripts[idx];\n",
       "\t\tif (script.src == src) {\n",
       "\t\t\trunJSFn(null);\n",
       "\t\t\treturn;\n",
       "\t\t}\n",
       "\t}\n",
       "\n",
       "\tvar script = document.createElement(\"script\");\n",
       "\n",
       "\tscript.charset = \"utf-8\";\n",
       "\t\n",
       "\tscript.src = src;\n",
       "\tscript.onload = script.onreadystatechange = function () { runJSFn(null); };\n",
       "\tdocument.head.appendChild(script);\t\n",
       "})();\n",
       "</script>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<p><b>Metric: loss</b></p>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div id=\"63dda777\"></div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<script charset=\"UTF-8\">\n",
       "(() => {\n",
       "\tconst src=\"https://cdn.plot.ly/plotly-2.34.0.min.js\";\n",
       "\tvar runJSFn = function(module) {\n",
       "\t\t\n",
       "\tif (!module) {\n",
       "\t\tmodule = window.Plotly;\n",
       "\t}\n",
       "\tlet data = JSON.parse('{\"data\":[{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Batch Loss+Regularization\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,10000],\"y\":[4.7155280113220215,4.397965431213379,4.16745662689209,4.029492378234863,3.801084518432617,3.4952831268310547,3.2545039653778076,2.987250566482544,2.5711283683776855,2.2175393104553223,2.0413482189178467,1.747159481048584,1.3202190399169922,1.4029536247253418]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average Loss+Regularization\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,10000],\"y\":[4.834479808807373,4.536516189575195,4.278083324432373,4.037747859954834,3.7985165119171143,3.5439233779907227,3.264681100845337,2.965831995010376,2.6450424194335938,2.3156611919403076,1.980506420135498,1.6643307209014893,1.3573976755142212,1.3149433135986328]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Train: Moving Average Loss\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,10000],\"y\":[0.7178478240966797,0.5502299666404724,0.45488831400871277,0.4039437174797058,0.3789888322353363,0.36305609345436096,0.34668341279029846,0.33404436707496643,0.3192087709903717,0.3102484345436096,0.30081287026405334,0.3028002679347992,0.2914741635322571,0.2884238362312317]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss+Regularization on batched 
train\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,10000],\"y\":[4.69100284576416,4.43433141708374,4.197225093841553,3.979001998901367,3.745727062225342,3.488443613052368,3.2171926498413086,2.91913104057312,2.610215425491333,2.2876877784729004,1.9533504247665405,1.6329786777496338,1.3387932777404785,1.3026310205459595]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss on batched train\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,10000],\"y\":[0.6225143671035767,0.5049232840538025,0.42999932169914246,0.3986232876777649,0.37605783343315125,0.353549599647522,0.3412187099456787,0.32514721155166626,0.3177129328250885,0.3109897971153259,0.29757702350616455,0.2904336154460907,0.28713876008987427,0.2896663248538971]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss+Regularization on test\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,10000],\"y\":[4.687947750091553,4.436037540435791,4.198253631591797,3.981177568435669,3.747331142425537,3.4936623573303223,3.2215986251831055,2.924823045730591,2.6151015758514404,2.292271137237549,1.9582329988479614,1.6420408487319946,1.3493889570236206,1.3144968748092651]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"Mean Loss on test\",\"x\":[201,441,729,1075,1490,1988,2586,3304,4166,5200,6441,7930,9717,10000],\"y\":[0.6194599270820618,0.5066288709640503,0.431027889251709,0.4007996618747711,0.3776625394821167,0.35876813530921936,0.34562498331069946,0.33083879947662354,0.3225991427898407,0.315573513507843,0.30245980620384216,0.299495667219162,0.2977348268032074,0.3015323281288147]}],\"layout\":{\"legend\":{},\"title\":{\"text\":\"loss\"},\"xaxis\":{\"showgrid\":true,\"title\":{\"text\":\"Steps\"},\"type\":\"log\"},\"yaxis\":{\"showgrid\":true,\"type\":\"log\"}}}');\n",
       "\tmodule.newPlot('63dda777', data);\n",
       "\n",
       "\t}\n",
       "\t\n",
       "    if (typeof requirejs === \"function\") {\n",
       "        // Use RequireJS to load module.\n",
       "\t\tlet srcWithoutExtension = src.substring(0, src.lastIndexOf(\".js\"));\n",
       "        requirejs.config({\n",
       "            paths: {\n",
       "                'plotly': srcWithoutExtension\n",
       "            }\n",
       "        });\n",
       "        require(['plotly'], function(plotly) {\n",
       "            runJSFn(plotly)\n",
       "        });\n",
       "        return\n",
       "    }\n",
       "\n",
       "\tvar currentScripts = document.head.getElementsByTagName(\"script\");\n",
       "\tfor (const idx in currentScripts) {\n",
       "\t\tlet script = currentScripts[idx];\n",
       "\t\tif (script.src == src) {\n",
       "\t\t\trunJSFn(null);\n",
       "\t\t\treturn;\n",
       "\t\t}\n",
       "\t}\n",
       "\n",
       "\tvar script = document.createElement(\"script\");\n",
       "\n",
       "\tscript.charset = \"utf-8\";\n",
       "\t\n",
       "\tscript.src = src;\n",
       "\tscript.onload = script.onreadystatechange = function () { runJSFn(null); };\n",
       "\tdocument.head.appendChild(script);\t\n",
       "})();\n",
       "</script>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t[Step 10000] median train step: 2652 microseconds\n",
      "Results on batched train:\n",
      "\tMean Loss+Regularization (#loss+): 1.3\n",
      "\tMean Loss (#loss): 0.29\n",
      "\tMean Accuracy (#acc): 87.07%\n",
      "Results on test:\n",
      "\tMean Loss+Regularization (#loss+): 1.31\n",
      "\tMean Loss (#loss): 0.302\n",
      "\tMean Accuracy (#acc): 86.23%\n"
     ]
    }
   ],
   "source": [
    "%% --checkpoint kan_model -set=\"kan=true;activation=swish;plots=true;train_steps=10_000\"\n",
    "trainModel(contextFromSettings())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Go (gonb)",
   "language": "go",
   "name": "gonb"
  },
  "language_info": {
   "codemirror_mode": "",
   "file_extension": ".go",
   "mimetype": "text/x-go",
   "name": "go",
   "nbconvert_exporter": "",
   "pygments_lexer": "",
   "version": "go1.25.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
